import gzip
import io
import json
import os
import shutil
import tempfile
import time
import uuid

import pytest

from tests.test_client_calls import texts_helium1, texts_helium2, texts_helium3, texts_helium4, texts_helium5, \
    texts_simple, texts_long
from tests.utils import wrap_test_forked, kill_weaviate, make_user_path_test
from src.enums import DocumentSubset, LangChainAction, LangChainMode, LangChainTypes, DocumentChoice, \
    docs_joiner_default, docs_token_handling_default, db_types, db_types_full
from src.utils import zip_data, download_simple, get_ngpus_vis, get_mem_gpus, have_faiss, remove, get_kwargs, \
    FakeTokenizer, get_token_count, flatten_list, tar_data
from src.gpt_langchain import get_persist_directory, get_db, get_documents, length_db1, _run_qa_db, split_merge_docs, \
    get_hyde_acc

# Credential availability flags, consumed by skipif markers in this module.
# os.environ values are always str, so membership is equivalent to ".get(...) is not None".
have_openai_key = 'OPENAI_API_KEY' in os.environ
have_replicate_key = 'REPLICATE_API_TOKEN' in os.environ

# Probe visible GPUs and their memory once, at import time.
have_gpus = get_ngpus_vis() > 0
mem_gpus = get_mem_gpus()

# FIXME(review): the reason for forcing this off is not recorded here; presumably it avoids the
# HF tokenizers parallelism warning/deadlock when tests fork. Confirm before removing.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'


@pytest.mark.skipif(not have_openai_key, reason="requires OpenAI key to run")
@wrap_test_forked
def test_qa_wiki_openai():
    """Wikipedia QA smoke test that uses an OpenAI model (skipped without OPENAI_API_KEY)."""
    outcome = run_qa_wiki_fork(use_openai_model=True)
    return outcome


@pytest.mark.need_gpu
@wrap_test_forked
def test_qa_wiki_stuff_hf():
    """Wikipedia QA with a local HF model and the 'stuff' chain."""
    # Each source is capped at 256: the combined context fails once
    # n_sources * text_limit exceeds roughly 2048.
    options = dict(use_openai_model=False, text_limit=256, chain_type='stuff', prompt_type='human_bot')
    return run_qa_wiki_fork(**options)


@pytest.mark.xfail(strict=False,
                   reason="Too long context, improve prompt for map_reduce.  Until then hit: The size of tensor a (2048) must match the size of tensor b (2125) at non-singleton dimension 3")
@wrap_test_forked
def test_qa_wiki_map_reduce_hf():
    """Wikipedia QA with a local HF model and the 'map_reduce' chain, with no text_limit.

    Marked as a non-strict xfail: the context is currently too long for the model.
    """
    options = dict(use_openai_model=False, text_limit=None, chain_type='map_reduce', prompt_type='human_bot')
    return run_qa_wiki_fork(**options)


def run_qa_wiki_fork(*args, **kwargs):
    """Run run_qa_wiki in the current process, forwarding all arguments.

    Despite the name, this does not fork. Forking fails with
    "RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with
    multiprocessing, you must use the 'spawn' start method", because other tests have
    already initialized CUDA in the parent. The previous forked variant was
    ``tests.utils.call_subprocess_onetask(run_qa_wiki, args=args, kwargs=kwargs)``.
    """
    return run_qa_wiki(*args, **kwargs)


def run_qa_wiki(use_openai_model=False, first_para=True, text_limit=None, chain_type='stuff', prompt_type=None):
    """Answer a fixed Linux-vs-Windows question from Wikipedia sources with a QA-with-sources chain.

    :param use_openai_model: passed to get_llm to choose an OpenAI model instead of a local one
    :param first_para: passed to get_wiki_sources
    :param text_limit: passed to get_wiki_sources to cap the length of each source text
    :param chain_type: langchain QA chain type, e.g. 'stuff' or 'map_reduce'
    :param prompt_type: passed to get_llm
    :raises AssertionError: if the chain produces an empty/falsy answer
    """
    from src.gpt_langchain import get_wiki_sources, get_llm, get_answer_from_sources
    from langchain.chains.qa_with_sources import load_qa_with_sources_chain

    sources = get_wiki_sources(first_para=first_para, text_limit=text_limit)
    # get_llm returns a 7-tuple; only the llm is needed to build the chain.
    llm, model_name, streamer, prompt_type_out, async_output, only_new_text, gradio_server = \
        get_llm(use_openai_model=use_openai_model, prompt_type=prompt_type, llamacpp_dict={},
                exllama_dict={})
    chain = load_qa_with_sources_chain(llm, chain_type=chain_type)

    question = "What are the main differences between Linux and Windows?"
    answer = get_answer_from_sources(chain, sources, question)
    print(answer)
    # Without an assertion the calling tests could only fail by raising; an empty answer
    # is a failure of the QA pipeline.
    assert answer, "QA chain returned an empty answer"


def check_ret(ret):
    """
    Drain a result generator, printing each item, and require that it yielded something.

    :param ret: iterable/generator of results (e.g. what _run_qa_db returns)
    :return: list of every yielded item, in order
    """
    collected = []
    for item in ret:
        collected.append(item)
        print(item)
    assert len(collected) > 0
    return collected


@pytest.mark.skipif(not have_openai_key, reason="requires OpenAI key to run")
@wrap_test_forked
def test_qa_wiki_db_openai():
    """QA over the shared 'wiki' collection (faiss) with an OpenAI LLM and OpenAI embeddings."""
    from src.gpt_langchain import _run_qa_db
    query = "What are the main differences between Linux and Windows?"
    langchain_mode = 'wiki'
    ret = _run_qa_db(query=query, use_openai_model=True, use_openai_embedding=True, text_limit=None,
                     hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                     db_type='faiss',
                     # mapping is keyed by the mode *name* ({'wiki': ...}), not the literal 'langchain_mode'
                     langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value},
                     langchain_mode=langchain_mode,
                     langchain_action=LangChainAction.QUERY.value, langchain_agents=[], llamacpp_dict={})
    check_ret(ret)


@pytest.mark.need_gpu
@wrap_test_forked
def test_qa_wiki_db_hf():
    """QA over the shared 'wiki' collection (faiss) with a local HF model and HF embeddings, unchunked."""
    from src.gpt_langchain import _run_qa_db
    # Without chunking, text_limit is still needed, but this setup can take more documents
    # by keeping only the top k.
    # FIXME: answers are currently garbage (fragmented, or a single word).
    query = "What are the main differences between Linux and Windows?"
    langchain_mode = 'wiki'
    ret = _run_qa_db(query=query, use_openai_model=False, use_openai_embedding=False, text_limit=256,
                     hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                     db_type='faiss',
                     # mapping is keyed by the mode *name* ({'wiki': ...}), not the literal 'langchain_mode'
                     langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value},
                     langchain_mode=langchain_mode,
                     langchain_action=LangChainAction.QUERY.value,
                     langchain_agents=[], llamacpp_dict={})
    check_ret(ret)


@pytest.mark.need_gpu
@wrap_test_forked
def test_qa_wiki_db_chunk_hf():
    """QA over the shared 'wiki' collection (faiss) with a local HF model, chunked to 256."""
    from src.gpt_langchain import _run_qa_db
    query = "What are the main differences between Linux and Windows?"
    langchain_mode = 'wiki'
    ret = _run_qa_db(query=query, use_openai_model=False, use_openai_embedding=False, text_limit=256, chunk=True,
                     chunk_size=256,
                     hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                     db_type='faiss',
                     # mapping is keyed by the mode *name* ({'wiki': ...}), not the literal 'langchain_mode'
                     langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value},
                     langchain_mode=langchain_mode,
                     langchain_action=LangChainAction.QUERY.value,
                     langchain_agents=[], llamacpp_dict={})
    check_ret(ret)


@pytest.mark.skipif(not have_openai_key, reason="requires OpenAI key to run")
@wrap_test_forked
def test_qa_wiki_db_chunk_openai():
    """QA over the shared 'wiki' collection (faiss) with OpenAI LLM/embeddings, chunked to 256."""
    from src.gpt_langchain import _run_qa_db
    # 256 isn't required for OpenAI; it matches the HF variant for comparison.
    query = "What are the main differences between Linux and Windows?"
    langchain_mode = 'wiki'
    ret = _run_qa_db(query=query, use_openai_model=True, use_openai_embedding=True, text_limit=256, chunk=True,
                     chunk_size=256,
                     hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                     db_type='faiss',
                     # mapping is keyed by the mode *name* ({'wiki': ...}), not the literal 'langchain_mode'
                     langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value},
                     langchain_mode=langchain_mode,
                     langchain_action=LangChainAction.QUERY.value,
                     langchain_agents=[], llamacpp_dict={})
    check_ret(ret)


@pytest.mark.skipif(not have_openai_key, reason="requires OpenAI key to run")
@wrap_test_forked
def test_qa_github_db_chunk_openai():
    """QA over the shared 'github h2oGPT' collection (faiss) with OpenAI LLM/embeddings, chunked to 256."""
    from src.gpt_langchain import _run_qa_db
    # 256 isn't required for OpenAI; it matches the HF variants for comparison.
    query = "what is a software defined asset"
    langchain_mode = 'github h2oGPT'
    ret = _run_qa_db(query=query, use_openai_model=True, use_openai_embedding=True, text_limit=256, chunk=True,
                     chunk_size=256,
                     hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                     db_type='faiss',
                     # mapping is keyed by the mode *name*, not the literal 'langchain_mode'
                     langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value},
                     langchain_mode=langchain_mode,
                     langchain_action=LangChainAction.QUERY.value,
                     langchain_agents=[], llamacpp_dict={})
    check_ret(ret)


@pytest.mark.need_gpu
@wrap_test_forked
def test_qa_daidocs_db_chunk_hf():
    """QA over the shared 'DriverlessAI docs' collection (faiss) with a local HF model, chunked to 128."""
    from src.gpt_langchain import _run_qa_db
    # FIXME: doesn't work well with non-instruct-tuned Cerebras
    query = "Which config.toml enables pytorch for NLP?"
    langchain_mode = 'DriverlessAI docs'
    ret = _run_qa_db(query=query, use_openai_model=False, use_openai_embedding=False, text_limit=None, chunk=True,
                     chunk_size=128,
                     hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                     db_type='faiss',
                     # mapping is keyed by the mode *name*, not the literal 'langchain_mode'
                     langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value},
                     langchain_mode=langchain_mode,
                     langchain_action=LangChainAction.QUERY.value,
                     langchain_agents=[], llamacpp_dict={})
    check_ret(ret)


@pytest.mark.skipif(not have_faiss, reason="requires FAISS")
@wrap_test_forked
def test_qa_daidocs_db_chunk_hf_faiss():
    """QA over the shared 'DriverlessAI docs' collection using a faiss db and HF embeddings."""
    from src.gpt_langchain import _run_qa_db
    query = "Which config.toml enables pytorch for NLP?"
    # chunk_size is in characters, for each of the (k=4) retrieved chunks
    langchain_mode = 'DriverlessAI docs'
    ret = _run_qa_db(query=query, use_openai_model=False, use_openai_embedding=False, text_limit=None, chunk=True,
                     chunk_size=128 * 1,
                     # mapping is keyed by the mode *name*, not the literal 'langchain_mode'
                     langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value},
                     langchain_mode=langchain_mode,
                     langchain_action=LangChainAction.QUERY.value,
                     langchain_agents=[],
                     llamacpp_dict={},
                     db_type='faiss',
                     hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                     )
    check_ret(ret)


@pytest.mark.need_gpu
@pytest.mark.parametrize("db_type", db_types)
@pytest.mark.parametrize("top_k_docs", [-1, 3])
@wrap_test_forked
def test_qa_daidocs_db_chunk_hf_dbs(db_type, top_k_docs):
    """Chunked QA over a freshly rebuilt shared 'DriverlessAI docs' db, per db type and top_k_docs.

    kill_weaviate(db_type) runs before the test and, in a ``finally``, after it, so a failed
    assertion or exception does not skip the cleanup.
    """
    kill_weaviate(db_type)
    try:
        langchain_mode = 'DriverlessAI docs'
        langchain_action = LangChainAction.QUERY.value
        langchain_agents = []
        # Wipe the shared persist directory so the db is rebuilt from scratch.
        persist_directory, langchain_type = get_persist_directory(langchain_mode,
                                                                  langchain_type=LangChainTypes.SHARED.value)
        assert langchain_type == LangChainTypes.SHARED.value
        remove(persist_directory)
        from src.gpt_langchain import _run_qa_db
        query = "Which config.toml enables pytorch for NLP?"
        if top_k_docs == -1:
            # Pin this model for top_k_docs=-1: otherwise generation OOMs as soon as it starts,
            # even with only ~1600 prompt tokens and 256 new tokens.
            model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
        else:
            model_name = None
        ret = _run_qa_db(query=query, use_openai_model=False, use_openai_embedding=False, text_limit=None, chunk=True,
                         chunk_size=128 * 1,  # characters per chunk
                         langchain_mode=langchain_mode,
                         langchain_action=langchain_action,
                         langchain_agents=langchain_agents,
                         hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                         db_type=db_type,
                         top_k_docs=top_k_docs,
                         model_name=model_name,
                         llamacpp_dict={},
                         )
        check_ret(ret)
    finally:
        kill_weaviate(db_type)


def get_test_model(base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
                   tokenizer_base_model='',
                   prompt_type='human_bot',
                   inference_server='',
                   max_seq_len=None,
                   regenerate_clients=True):
    """Load a model and tokenizer up front so a test can pass them into _run_qa_db.

    Loading externally (rather than letting each _run_qa_db call load its own) avoids OOM.
    Only the options below are forwarded to get_model_retry, after filtering them through
    get_kwargs against get_model's signature (reward_type is excluded and passed explicitly).

    :return: (model, tokenizer, base_model, prompt_type)
    """
    from src.gen import get_model, get_model_retry

    model_options = {
        'load_8bit': False,
        'load_4bit': False,
        'low_bit_mode': 1,
        'load_half': True,
        'load_gptq': '',
        'use_autogptq': False,
        'load_awq': '',
        'load_exllama': False,
        'use_safetensors': False,
        'revision': None,
        'use_gpu_id': True,
        'base_model': base_model,
        'tokenizer_base_model': tokenizer_base_model,
        'inference_server': inference_server,
        'regenerate_clients': regenerate_clients,
        'lora_weights': '',
        'gpu_id': 0,
        'n_jobs': 1,
        'n_gpus': None,
        'reward_type': False,
        'local_files_only': False,
        'resume_download': True,
        'use_auth_token': False,
        'trust_remote_code': True,
        'offload_folder': None,
        'rope_scaling': None,
        'max_seq_len': max_seq_len,
        'compile_model': True,
        'llamacpp_dict': {},
        'exllama_dict': {},
        'gptq_dict': {},
        'attention_sinks': False,
        'sink_dict': {},
        'truncation_generation': False,
        'hf_model_dict': {},
        'use_flash_attention_2': False,
        'llamacpp_path': 'llamacpp_path',
        'regenerate_gradio_clients': True,
        'max_output_seq_len': None,
        'force_seq2seq_type': False,
        'force_t5_type': False,
        'verbose': False,
    }
    accepted = get_kwargs(get_model, exclude_names=['reward_type'], **model_options)
    model, tokenizer, _device = get_model_retry(reward_type=False, **accepted)
    return model, tokenizer, base_model, prompt_type


@pytest.mark.need_gpu
@pytest.mark.parametrize("db_type", ['chroma'])
@wrap_test_forked
def test_qa_daidocs_db_chunk_hf_dbs_switch_embedding(db_type):
    """Query a freshly rebuilt shared 'DriverlessAI docs' db, then query it again with another embedding.

    The first pass uses all-MiniLM-L6-v2 and the second BAAI/bge-large-en-v1.5, both with
    migrate_embedding_model=True. NOTE(review): presumably the second pass exercises
    migrating the existing db to the new embedding; confirm against _run_qa_db.
    """
    model, tokenizer, base_model, prompt_type = get_test_model()

    mode = 'DriverlessAI docs'
    action = LangChainAction.QUERY.value
    agents = []  # one list shared by both passes
    persist_dir, mode_type = get_persist_directory(mode, langchain_type=LangChainTypes.SHARED.value)
    assert mode_type == LangChainTypes.SHARED.value
    remove(persist_dir)

    from src.gpt_langchain import _run_qa_db
    embedding_models = ("sentence-transformers/all-MiniLM-L6-v2", 'BAAI/bge-large-en-v1.5')
    for embedding_model in embedding_models:
        question = "Which config.toml enables pytorch for NLP?"
        results = _run_qa_db(query=question, use_openai_model=False, use_openai_embedding=False,
                             hf_embedding_model=embedding_model,
                             migrate_embedding_model=True,
                             model=model,
                             tokenizer=tokenizer,
                             model_name=base_model,
                             prompt_type=prompt_type,
                             text_limit=None, chunk=True,
                             chunk_size=128 * 1,  # characters per chunk
                             langchain_mode=mode,
                             langchain_action=action,
                             langchain_agents=agents,
                             db_type=db_type,
                             llamacpp_dict={},
                             )
        check_ret(results)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_qa_wiki_db_chunk_hf_dbs_llama(db_type):
    """Chunked QA over the shared 'wiki' collection with a llama.cpp Llama-2-7B-chat GGUF model, per db type.

    kill_weaviate(db_type) runs before the test and, in a ``finally``, after it, so a failure
    does not skip the cleanup.
    """
    kill_weaviate(db_type)
    try:
        from src.gpt4all_llm import get_model_tokenizer_gpt4all
        model_name = 'llama'
        model, tokenizer, _device = get_model_tokenizer_gpt4all(model_name,
                                                                n_jobs=8,
                                                                max_seq_len=512,
                                                                llamacpp_dict=dict(
                                                                    model_path_llama='https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q6_K.gguf?download=true',
                                                                    n_gpu_layers=100,
                                                                    use_mlock=True,
                                                                    n_batch=1024))

        from src.gpt_langchain import _run_qa_db
        query = "What are the main differences between Linux and Windows?"
        # chunk_size is in characters, for each of the (k=4) retrieved chunks
        langchain_mode = 'wiki'
        ret = _run_qa_db(query=query, use_openai_model=False, use_openai_embedding=False, text_limit=None, chunk=True,
                         chunk_size=128 * 1,
                         hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                         # mapping is keyed by the mode *name* ({'wiki': ...}), not the literal 'langchain_mode'
                         langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value},
                         langchain_mode=langchain_mode,
                         langchain_action=LangChainAction.QUERY.value,
                         langchain_agents=[],
                         db_type=db_type,
                         prompt_type='llama2',
                         langchain_only_model=True,
                         model_name=model_name, model=model, tokenizer=tokenizer,
                         llamacpp_dict=dict(n_gpu_layers=100, use_mlock=True, n_batch=1024),
                         )
        check_ret(ret)
    finally:
        kill_weaviate(db_type)


@pytest.mark.skipif(not have_openai_key, reason="requires OpenAI key to run")
@wrap_test_forked
def test_qa_daidocs_db_chunk_openai():
    """QA over the shared 'DriverlessAI docs' collection (faiss) with OpenAI LLM/embeddings, chunked to 256."""
    from src.gpt_langchain import _run_qa_db
    query = "Which config.toml enables pytorch for NLP?"
    langchain_mode = 'DriverlessAI docs'
    ret = _run_qa_db(query=query, use_openai_model=True, use_openai_embedding=True, text_limit=256, chunk=True,
                     db_type='faiss',
                     hf_embedding_model="",
                     chunk_size=256,
                     # mapping is keyed by the mode *name*, not the literal 'langchain_mode'
                     langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value},
                     langchain_mode=langchain_mode,
                     langchain_action=LangChainAction.QUERY.value,
                     langchain_agents=[], llamacpp_dict={})
    check_ret(ret)


@pytest.mark.skipif(not have_openai_key, reason="requires OpenAI key to run")
@wrap_test_forked
def test_qa_daidocs_db_chunk_openaiembedding_hfmodel():
    """QA over the shared 'DriverlessAI docs' collection (faiss) with OpenAI embeddings and a non-OpenAI LLM."""
    from src.gpt_langchain import _run_qa_db
    query = "Which config.toml enables pytorch for NLP?"
    langchain_mode = 'DriverlessAI docs'
    ret = _run_qa_db(query=query, use_openai_model=False, use_openai_embedding=True, text_limit=None, chunk=True,
                     chunk_size=128,
                     hf_embedding_model="",
                     db_type='faiss',
                     # mapping is keyed by the mode *name*, not the literal 'langchain_mode'
                     langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value},
                     langchain_mode=langchain_mode,
                     langchain_action=LangChainAction.QUERY.value,
                     langchain_agents=[], llamacpp_dict={})
    check_ret(ret)


@pytest.mark.need_tokens
@wrap_test_forked
def test_get_dai_pickle():
    """get_dai_pickle(dest=...) must leave a 'dai_docs.pickle' file in the destination directory."""
    from src.gpt_langchain import get_dai_pickle
    with tempfile.TemporaryDirectory() as dest_dir:
        get_dai_pickle(dest=dest_dir)
        expected_path = os.path.join(dest_dir, 'dai_docs.pickle')
        assert os.path.isfile(expected_path)


@pytest.mark.need_tokens
@wrap_test_forked
def test_get_dai_db_dir():
    """Smoke test: get_some_dbs_from_hf must complete without raising for a temporary destination.

    NOTE(review): nothing about the destination's contents is asserted.
    """
    from src.gpt_langchain import get_some_dbs_from_hf
    with tempfile.TemporaryDirectory() as target_dir:
        get_some_dbs_from_hf(target_dir)


# repeat runs the test twice: if the first run's data were not really deleted, results would
# accumulate and the second run's count/order assertions would fail.
@pytest.mark.parametrize("repeat", [0, 1])
@pytest.mark.parametrize("db_type", db_types_full)
@wrap_test_forked
def test_make_add_db(repeat, db_type):
    """Build, query, and extend 'UserData' and 'MyData' dbs end to end for one db type.

    Steps:
    1. make_db_main builds a UserData db and a MyData db, each from a one-file temp dir;
       similarity_search("World") must return that file's text and source path.
    2. Exercise the gradio UI/client helpers (get_source_files*, get_any_db, get_sources,
       update_and_get_source_files_given_langchain_mode) against both dbs.
    3. update_user_db adds a second file ("Beefy Chicken") to each db.
    4. path_to_docs is checked for a file, a directory, a list of files, an arxiv id, a URL
       and pasted text. NOTE(review): the arxiv/URL cases presumably need network access.
    5. For db types other than faiss, reopen the persisted UserData db with add_if_exists=True,
       add a third file from a new user path, and check the search results.

    :param repeat: not read in the body; only duplicates the run (see comment above).
    :param db_type: vector db backend under test.
    """
    kill_weaviate(db_type)
    from src.gpt_langchain import get_source_files, get_source_files_given_langchain_mode, get_any_db, update_user_db, \
        get_sources, update_and_get_source_files_given_langchain_mode
    from src.make_db import make_db_main
    from src.gpt_langchain import path_to_docs
    with tempfile.TemporaryDirectory() as tmp_persist_directory:
        with tempfile.TemporaryDirectory() as tmp_user_path:
            with tempfile.TemporaryDirectory() as tmp_persist_directory_my:
                with tempfile.TemporaryDirectory() as tmp_user_path_my:
                    # Step 1a: UserData db built from a single "Hello World" file.
                    msg1 = "Hello World"
                    test_file1 = os.path.join(tmp_user_path, 'test.txt')
                    with open(test_file1, "wt") as f:
                        f.write(msg1)
                    chunk = True
                    chunk_size = 512
                    langchain_mode = 'UserData'
                    db, collection_name = make_db_main(persist_directory=tmp_persist_directory,
                                                       user_path=tmp_user_path,
                                                       add_if_exists=False,
                                                       collection_name=langchain_mode,
                                                       fail_any_exception=True, db_type=db_type)
                    assert db is not None
                    docs = db.similarity_search("World")
                    assert len(docs) >= 1
                    assert docs[0].page_content == msg1
                    assert os.path.normpath(docs[0].metadata['source']) == os.path.normpath(test_file1)

                    # Step 1b: MyData db built the same way in its own directories.
                    test_file1my = os.path.join(tmp_user_path_my, 'test.txt')
                    with open(test_file1my, "wt") as f:
                        f.write(msg1)
                    dbmy, collection_namemy = make_db_main(persist_directory=tmp_persist_directory_my,
                                                           user_path=tmp_user_path_my,
                                                           add_if_exists=False,
                                                           collection_name='MyData',
                                                           fail_any_exception=True, db_type=db_type)
                    # Per-user db state: {mode: [db, uuid, username]} with placeholder uuid/username.
                    db1 = {LangChainMode.MY_DATA.value: [dbmy, 'foouuid', 'foousername']}
                    assert dbmy is not None
                    docs1 = dbmy.similarity_search("World")
                    # chroma returns one extra hit here.
                    # NOTE(review): presumably a placeholder doc chroma dbs hold; confirm.
                    assert len(docs1) == 1 + (1 if db_type == 'chroma' else 0)
                    assert docs1[0].page_content == msg1
                    assert os.path.normpath(docs1[0].metadata['source']) == os.path.normpath(test_file1my)

                    # Step 2: some db testing for gradio UI/client (only checks these calls don't raise)
                    get_source_files(db=db)
                    get_source_files(db=dbmy)
                    selection_docs_state1 = dict(langchain_modes=[langchain_mode], langchain_mode_paths={},
                                                 langchain_mode_types={})
                    requests_state1 = dict()
                    get_source_files_given_langchain_mode(db1, selection_docs_state1, requests_state1, None,
                                                          langchain_mode, dbs={langchain_mode: db})
                    get_source_files_given_langchain_mode(db1, selection_docs_state1, requests_state1, None,
                                                          langchain_mode='MyData', dbs={})
                    get_any_db(db1, langchain_mode='UserData',
                               langchain_mode_paths=selection_docs_state1['langchain_mode_paths'],
                               langchain_mode_types=selection_docs_state1['langchain_mode_types'],
                               dbs={langchain_mode: db})
                    get_any_db(db1, langchain_mode='MyData',
                               langchain_mode_paths=selection_docs_state1['langchain_mode_paths'],
                               langchain_mode_types=selection_docs_state1['langchain_mode_types'],
                               dbs={})

                    # Step 3: add a second file to each db via update_user_db.
                    msg1up = "Beefy Chicken"
                    test_file2 = os.path.join(tmp_user_path, 'test2.txt')
                    with open(test_file2, "wt") as f:
                        f.write(msg1up)
                    test_file2_my = os.path.join(tmp_user_path_my, 'test2my.txt')
                    with open(test_file2_my, "wt") as f:
                        f.write(msg1up)
                    # Shared update_user_db options: switches the embedding model (with migration)
                    # and disables all caption/OCR/ASR/vision loaders.
                    kwargs = dict(use_openai_embedding=False,
                                  hf_embedding_model='BAAI/bge-large-en-v1.5',
                                  migrate_embedding_model=True,
                                  caption_loader=False,
                                  doctr_loader=False,
                                  asr_loader=False,
                                  enable_captions=False,
                                  enable_doctr=False,
                                  enable_pix2struct=False,
                                  enable_llava=False,
                                  enable_transcriptions=False,
                                  captions_model="microsoft/Florence-2-base",
                                  llava_model=None,
                                  llava_prompt=None,
                                  asr_model='openai/whisper-medium',
                                  enable_ocr=False,
                                  enable_pdf_ocr='auto',
                                  enable_pdf_doctr=False,
                                  gradio_upload_to_chatbot_num_max=1,
                                  verbose=False,
                                  is_url=False, is_txt=False,
                                  allow_upload_to_my_data=True,
                                  allow_upload_to_user_data=True,
                                  )
                    langchain_mode2 = 'MyData'
                    selection_docs_state2 = dict(langchain_modes=[langchain_mode2],
                                                 langchain_mode_paths={},
                                                 langchain_mode_types={})
                    requests_state2 = dict()
                    z1, z2, source_files_added, exceptions, last_file, last_dict = update_user_db(test_file2_my, db1,
                                                                                                  selection_docs_state2,
                                                                                                  requests_state2,
                                                                                                  langchain_mode2,
                                                                                                  chunk=chunk,
                                                                                                  chunk_size=chunk_size,
                                                                                                  dbs={},
                                                                                                  db_type=db_type,
                                                                                                  **kwargs)
                    assert z1 is None
                    assert 'MyData' == z2
                    assert 'test2my' in str(source_files_added)
                    assert len(exceptions) == 0

                    langchain_mode = 'UserData'
                    # UserData is registered as a shared mode whose path is tmp_user_path.
                    selection_docs_state1 = dict(langchain_modes=[langchain_mode],
                                                 langchain_mode_paths={langchain_mode: tmp_user_path},
                                                 langchain_mode_types={langchain_mode: LangChainTypes.SHARED.value})
                    z1, z2, source_files_added, exceptions, last_file, last_dict = update_user_db(test_file2, db1,
                                                                                                  selection_docs_state1,
                                                                                                  requests_state1,
                                                                                                  langchain_mode,
                                                                                                  chunk=chunk,
                                                                                                  chunk_size=chunk_size,
                                                                                                  dbs={
                                                                                                      langchain_mode: db},
                                                                                                  db_type=db_type,
                                                                                                  **kwargs)
                    assert 'test2' in str(source_files_added)
                    assert langchain_mode == z2
                    assert z1 is None
                    docs_state0 = [x.name for x in list(DocumentSubset)]
                    get_sources(db1, selection_docs_state1, {}, langchain_mode, dbs={langchain_mode: db},
                                docs_state0=docs_state0)
                    get_sources(db1, selection_docs_state1, {}, 'MyData', dbs={}, docs_state0=docs_state0)
                    # Re-assigns the same mapping already set in selection_docs_state1 above.
                    selection_docs_state1['langchain_mode_paths'] = {langchain_mode: tmp_user_path}
                    kwargs2 = dict(first_para=False,
                                   text_limit=None, chunk=chunk, chunk_size=chunk_size,
                                   db_type=db_type,
                                   hf_embedding_model=kwargs['hf_embedding_model'],
                                   migrate_embedding_model=kwargs['migrate_embedding_model'],
                                   load_db_if_exists=True,
                                   n_jobs=-1, verbose=False)
                    update_and_get_source_files_given_langchain_mode(db1,
                                                                     selection_docs_state1, requests_state1,
                                                                     langchain_mode, dbs={langchain_mode: db},
                                                                     **kwargs2)
                    update_and_get_source_files_given_langchain_mode(db1,
                                                                     selection_docs_state2, requests_state2,
                                                                     'MyData', dbs={}, **kwargs2)

                    # Step 4: path_to_docs over a file, a directory, a file list, arxiv, a URL, and text.
                    # NOTE(review): unlike the other source checks, this one compares without normpath.
                    assert path_to_docs(test_file2_my, db_type=db_type)[0].metadata['source'] == test_file2_my
                    # chroma adds one extra leading entry (see the len(docs1) check above).
                    extra = 1 if db_type == 'chroma' else 0
                    assert os.path.normpath(
                        path_to_docs(os.path.dirname(test_file2_my), db_type=db_type)[1 + extra].metadata[
                            'source']) == os.path.normpath(
                        os.path.abspath(test_file2_my))
                    assert path_to_docs([test_file1, test_file2, test_file2_my], db_type=db_type)[0].metadata[
                               'source'] == test_file1

                    assert path_to_docs(None, url='arxiv:1706.03762', db_type=db_type)[0].metadata[
                               'source'] == 'http://arxiv.org/abs/1706.03762v7'
                    assert path_to_docs(None, url='http://h2o.ai', db_type=db_type)[0].metadata[
                               'source'] == 'http://h2o.ai'

                    assert 'user_paste' in path_to_docs(None,
                                                        text='Yufuu is a wonderful place and you should really visit because there is lots of sun.',
                                                        db_type=db_type)[0].metadata['source']

                if db_type == 'faiss':
                    # faiss dbs are not persisted, so the reopen phase below cannot run.
                    # NOTE: this early return also skips the trailing kill_weaviate(db_type).
                    return

                # Step 5: now add using new source path, to original persisted
                with tempfile.TemporaryDirectory() as tmp_user_path3:
                    msg2 = "Jill ran up the hill"
                    test_file2 = os.path.join(tmp_user_path3, 'test2.txt')
                    with open(test_file2, "wt") as f:
                        f.write(msg2)
                    db, collection_name = make_db_main(persist_directory=tmp_persist_directory,
                                                       user_path=tmp_user_path3,
                                                       add_if_exists=True,
                                                       fail_any_exception=True, db_type=db_type,
                                                       collection_name=collection_name)
                    assert db is not None
                    docs = db.similarity_search("World")
                    assert len(docs) >= 1
                    assert docs[0].page_content == msg1
                    # `extra` (set in the block above) skips chroma's additional hit.
                    assert docs[1 + extra].page_content in [msg2, msg1up]
                    assert docs[2 + extra].page_content in [msg2, msg1up]
                    assert os.path.normpath(docs[0].metadata['source']) == os.path.normpath(test_file1)

                    docs = db.similarity_search("Jill")
                    assert len(docs) >= 1
                    assert docs[0].page_content == msg2
                    assert os.path.normpath(docs[0].metadata['source']) == os.path.normpath(test_file2)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_zip_add(db_type):
    """Archive a one-file text directory, index that directory, and query it back.

    NOTE(review): the archive written to ``./tmpdata/data.zip`` is never passed to
    ``make_db_main`` -- only ``user_path`` (the plain directory) is indexed -- so zip
    ingestion itself is not exercised here. Confirm whether that is intended.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        expected_text = "Hello World"
        txt_path = os.path.join(user_dir, 'test.txt')
        with open(txt_path, "wt") as fh:
            fh.write(expected_text)
        zip_data(user_dir, zip_file='./tmpdata/data.zip', fail_any_exception=True)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             add_if_exists=False,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("World")
        assert len(hits) >= 1
        top = hits[0]
        assert top.page_content == expected_text
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(txt_path)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@pytest.mark.parametrize("tar_type", ["tar.gz", "tgz"])
@wrap_test_forked
def test_tar_add(db_type, tar_type):
    """Archive a one-file text directory as ``tar_type``, index the directory, and query it back.

    NOTE(review): the archive written under ``./tmpdata/`` is never passed to
    ``make_db_main`` -- only ``user_path`` is indexed -- so tar ingestion itself is not
    exercised here. Confirm whether that is intended.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        expected_text = "Hello World"
        txt_path = os.path.join(user_dir, 'test.txt')
        with open(txt_path, "wt") as fh:
            fh.write(expected_text)
        archive_path = f'./tmpdata/data.{tar_type}'
        tar_data(user_dir, tar_file=archive_path, fail_any_exception=True)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             add_if_exists=False,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("World")
        assert len(hits) >= 1
        top = hits[0]
        assert top.page_content == expected_text
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(txt_path)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_url_add(db_type):
    """Index a single web page by URL and check the leadership content ranks first."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    leadership_url = 'https://h2o.ai/company/team/leadership-team/'
    with tempfile.TemporaryDirectory() as persist_dir:
        db, _ = make_db_main(persist_directory=persist_dir,
                             url=leadership_url,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("list founding team of h2o.ai")
        assert len(hits) >= 1
        assert 'Sri Ambati' in hits[0].page_content
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_urls_add(db_type):
    """Index several web pages in one call and check the leadership content ranks first.

    The stored-chunk-count check is only applied for chroma.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir:
        page_urls = [
            'https://h2o.ai/company/team/leadership-team/',
            'https://arxiv.org/abs/1706.03762',
            'https://github.com/h2oai/h2ogpt',
            'https://h2o.ai',
        ]
        db, _ = make_db_main(persist_directory=persist_dir,
                             url=page_urls,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        if db_type == 'chroma':
            stored_docs = db.get()['documents']
            assert len(stored_docs) > 48
        hits = db.similarity_search("list founding team of h2o.ai")
        assert len(hits) >= 1
        assert 'Sri Ambati' in hits[0].page_content
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_urls_file_add(db_type):
    """Index web pages supplied both via ``url=`` and as a ``list.urls`` file in ``user_path``.

    Checks the leadership content ranks first; for chroma also checks the stored-chunk count.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        page_urls = [
            'https://h2o.ai/company/team/leadership-team/',
            'https://arxiv.org/abs/1706.03762',
            'https://github.com/h2oai/h2ogpt',
            'https://h2o.ai',
        ]
        urls_file = os.path.join(user_dir, 'list.urls')
        with open(urls_file, 'wt') as fh:
            fh.write('\n'.join(page_urls))

        db, _ = make_db_main(persist_directory=persist_dir,
                             url=page_urls,
                             user_path=user_dir,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        if db_type == 'chroma':
            stored_docs = db.get()['documents']
            assert len(stored_docs) > 45
        hits = db.similarity_search("list founding team of h2o.ai")
        assert len(hits) >= 1
        assert 'Sri Ambati' in hits[0].page_content
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_html_add(db_type):
    """Write a tiny HTML page, index it, and check the "Yugu" text is retrieved from it."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        page_markup = """
<!DOCTYPE html>
<html>
<body>

<h1>Yugu is a wonderful place</h1>

<p>Animals love to run in the world of Yugu.  They play all day long in the alien sun.</p>

</body>
</html>
"""
        html_path = os.path.join(user_dir, 'test.html')
        with open(html_path, "wt") as fh:
            fh.write(page_markup)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             add_if_exists=False,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("Yugu")
        assert len(hits) >= 1
        top = hits[0]
        assert 'Yugu' in top.page_content
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(html_path)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_docx_add(db_type):
    """Download the calibre demo .docx, index it, and query it back.

    The top hit's source may be the .docx itself or any path containing 'image'
    (presumably an image extracted from the document), so both are accepted.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        docx_path = os.path.join(user_dir, 'demo.docx')
        download_simple('https://calibre-ebook.com/downloads/demos/demo.docx', dest=docx_path)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("What is calibre DOCX plugin do?")
        assert len(hits) >= 1
        top = hits[0]
        assert any(marker in top.page_content for marker in ('calibre', 'an arrow pointing'))
        top_source = os.path.normpath(top.metadata['source'])
        assert top_source == os.path.normpath(docx_path) or 'image' in top_source
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_docx_add2(db_type):
    """Index ``tests/table_as_image.docx`` (with doctr enabled) and query its table text.

    The top hit's source may be the copied .docx itself or a path containing
    ``image1.png``; both are accepted.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as tmp_persist_directory:
        with tempfile.TemporaryDirectory() as tmp_user_path:
            # Use the name of the file actually copied in, so the source comparison
            # below can match it (a non-existent 'demo.docx' could never match).
            test_file1 = os.path.join(tmp_user_path, 'table_as_image.docx')
            shutil.copy('tests/table_as_image.docx', test_file1)
            db, collection_name = make_db_main(persist_directory=tmp_persist_directory, user_path=tmp_user_path,
                                               fail_any_exception=True, db_type=db_type,
                                               llava_model=os.getenv('H2OGPT_LLAVA_MODEL'),
                                               enable_doctr=True,
                                               )
            assert db is not None
            docs = db.similarity_search("Approver 1", k=4)
            assert len(docs) >= 1
            assert 'Band D' in docs[0].page_content
            top_source = os.path.normpath(docs[0].metadata['source'])
            assert top_source == os.path.normpath(test_file1) or 'image1.png' in top_source
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_xls_add(db_type):
    """Index ``data/example.xlsx`` and check a spreadsheet row is retrieved for "Profit"."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        xlsx_path = os.path.join(user_dir, 'example.xlsx')
        shutil.copy('data/example.xlsx', user_dir)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("What is Profit?")
        assert len(hits) >= 1
        top = hits[0]
        acceptable = ('16185', 'Small Business', 'United States of America')
        assert any(token in top.page_content for token in acceptable)
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(xlsx_path)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_md_add(db_type):
    """Index the repository README.md and check a descriptive passage is retrieved.

    Looks for ``README.md`` in the current directory, falling back to
    ``../README.md`` when run from the tests directory.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as tmp_persist_directory:
        with tempfile.TemporaryDirectory() as tmp_user_path:
            test_file1 = 'README.md'
            if not os.path.isfile(test_file1):
                # see if ran from tests directory
                test_file1 = os.path.abspath('../README.md')
            shutil.copy(test_file1, tmp_user_path)
            test_file1 = os.path.join(tmp_user_path, os.path.basename(test_file1))
            db, collection_name = make_db_main(persist_directory=tmp_persist_directory, user_path=tmp_user_path,
                                               fail_any_exception=True, db_type=db_type)
            assert db is not None
            docs = db.similarity_search("What is h2oGPT?")
            # The content check below inspects the second hit, so require at least two
            # results to get a clear assertion failure rather than an IndexError.
            assert len(docs) >= 2
            second = docs[1].page_content
            assert 'Query and summarize your documents' in second or \
                   'document Q/A' in second or \
                   'go to your browser by visiting' in second
            assert os.path.normpath(docs[0].metadata['source']) == os.path.normpath(test_file1)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_rst_add(db_type):
    """Download a sample reStructuredText README, index it, and query a known section."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as tmp_persist_directory:
        with tempfile.TemporaryDirectory() as tmp_user_path:
            url = 'https://gist.githubusercontent.com/javiertejero/4585196/raw/21786e2145c0cc0a202ffc4f257f99c26985eaea/README.rst'
            # Already an absolute path inside tmp_user_path; no re-join needed.
            test_file1 = os.path.join(tmp_user_path, 'demo.rst')
            download_simple(url, dest=test_file1)
            db, collection_name = make_db_main(persist_directory=tmp_persist_directory, user_path=tmp_user_path,
                                               fail_any_exception=True, db_type=db_type)
            assert db is not None
            docs = db.similarity_search("Font Faces - Emphasis and Examples")
            assert len(docs) >= 1
            assert 'Within paragraphs, inline markup' in docs[0].page_content
            assert os.path.normpath(docs[0].metadata['source']) == os.path.normpath(test_file1)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_xml_add(db_type):
    """Download a sample XML document (saved as ``demo.xml``), index it, and query it."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as tmp_persist_directory:
        with tempfile.TemporaryDirectory() as tmp_user_path:
            url = 'https://gist.githubusercontent.com/theresajayne/1409545/raw/a8b46e7799805e86f4339172c9778fa55afb0f30/gistfile1.txt'
            # Already a path inside tmp_user_path; no re-join needed.
            test_file1 = os.path.join(tmp_user_path, 'demo.xml')
            download_simple(url, dest=test_file1)
            db, collection_name = make_db_main(persist_directory=tmp_persist_directory, user_path=tmp_user_path,
                                               fail_any_exception=True, db_type=db_type)
            assert db is not None
            docs = db.similarity_search("Entrance Hall")
            assert len(docs) >= 1
            assert 'Ensuite Bathroom' in docs[0].page_content
            assert os.path.normpath(docs[0].metadata['source']) == os.path.normpath(test_file1)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_eml_add(db_type):
    """Index the checked-in ``tests/sample.eml`` and check its text is retrieved."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        eml_path = os.path.join(user_dir, 'sample.eml')
        shutil.copy('tests/sample.eml', eml_path)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             add_if_exists=False,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("What is subject?")
        assert len(hits) >= 1
        top = hits[0]
        assert 'testtest' in top.page_content
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(eml_path)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_simple_eml_add(db_type):
    """Write a minimal plain-text email, index it, and check the body is retrieved."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        eml_text = """
Date: Sun, 1 Apr 2012 14:25:25 -0600
From: file@fyicenter.com
Subject: Welcome
To: someone@somewhere.com

Dear Friend,

Welcome to file.fyicenter.com!

Sincerely,
FYIcenter.com Team"""
        eml_path = os.path.join(user_dir, 'test.eml')
        with open(eml_path, "wt") as fh:
            fh.write(eml_text)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             add_if_exists=False,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("Subject")
        assert len(hits) >= 1
        top = hits[0]
        assert 'Welcome' in top.page_content
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(eml_path)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_odt_add(db_type):
    """Download an example OpenDocument text file, index it, and query it back."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        odt_path = os.path.join(user_dir, 'sample.odt')
        download_simple('https://github.com/owncloud/example-files/raw/master/Documents/Example.odt',
                        dest=odt_path)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("What is ownCloud?")
        assert len(hits) >= 1
        top = hits[0]
        assert 'ownCloud' in top.page_content
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(odt_path)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_pptx_add(db_type):
    """Download an example presentation, index it, and query it back.

    NOTE(review): the download URL ends in ``.ppt`` (legacy format) but the file is
    saved as ``sample.pptx``; confirm the loader is expected to handle that mismatch.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        deck_path = os.path.join(user_dir, 'sample.pptx')
        download_simple('https://www.unm.edu/~unmvclib/powerpoint/pptexamples.ppt', dest=deck_path)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             add_if_exists=False,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("Suggestions")
        assert len(hits) >= 1
        top = hits[0]
        assert 'Presentation' in top.page_content
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(deck_path)
    kill_weaviate(db_type)


@pytest.mark.parametrize("use_pypdf", ['auto', 'on', 'off'])
@pytest.mark.parametrize("use_unstructured_pdf", ['auto', 'on', 'off'])
@pytest.mark.parametrize("use_pymupdf", ['auto', 'on', 'off'])
@pytest.mark.parametrize("enable_pdf_doctr", ['auto', 'on', 'off'])
@pytest.mark.parametrize("enable_pdf_ocr", ['auto', 'on', 'off'])
@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_pdf_add(db_type, enable_pdf_ocr, enable_pdf_doctr, use_pymupdf, use_unstructured_pdf, use_pypdf):
    """Index ``tests/sample.pdf`` under every PDF parser/OCR option combination.

    When every text extractor is switched off (OCR at most 'auto'), a "no valid text"
    ingestion error is tolerated and the test ends early; any other error propagates.
    Otherwise the top hit for "Suggestions" must contain the sample's body text and be
    attributed to the copied PDF (for weaviate a basename match is also accepted).
    ``kill_weaviate`` runs on every exit path.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    try:
        with tempfile.TemporaryDirectory() as tmp_persist_directory:
            with tempfile.TemporaryDirectory() as tmp_user_path:
                test_file1 = os.path.join(tmp_user_path, 'sample2.pdf')
                shutil.copy(os.path.join('tests', 'sample.pdf'), test_file1)

                # Every extractor disabled: ingestion may legitimately find no text.
                no_doc_mode = use_pymupdf == 'off' and \
                              use_pypdf == 'off' and \
                              use_unstructured_pdf == 'off' and \
                              enable_pdf_doctr == 'off' and \
                              enable_pdf_ocr in ['off', 'auto']

                try:
                    db, collection_name = make_db_main(persist_directory=tmp_persist_directory,
                                                       user_path=tmp_user_path,
                                                       fail_any_exception=True, db_type=db_type,
                                                       use_pymupdf=use_pymupdf,
                                                       enable_pdf_ocr=enable_pdf_ocr,
                                                       enable_pdf_doctr=enable_pdf_doctr,
                                                       use_unstructured_pdf=use_unstructured_pdf,
                                                       use_pypdf=use_pypdf,
                                                       add_if_exists=False)
                except Exception as e:
                    no_text = 'had no valid text and no meta data was parsed' in str(e) or \
                              'had no valid text, but meta data was parsed' in str(e)
                    if no_text and no_doc_mode:
                        return
                    raise

                assert db is not None
                docs = db.similarity_search("Suggestions")
                # ocr etc. end up with different pages, overly complex to test exact count
                assert len(docs) >= 1
                assert 'And more text. And more text.' in docs[0].page_content
                top_source = docs[0].metadata['source']
                if db_type == 'weaviate':
                    assert os.path.normpath(top_source) == os.path.normpath(test_file1) or \
                           os.path.basename(top_source) == os.path.basename(test_file1)
                else:
                    assert os.path.normpath(top_source) == os.path.normpath(test_file1)
    finally:
        kill_weaviate(db_type)


@pytest.mark.parametrize("use_pypdf", ['auto', 'on', 'off'])
@pytest.mark.parametrize("use_unstructured_pdf", ['auto', 'on', 'off'])
@pytest.mark.parametrize("use_pymupdf", ['auto', 'on', 'off'])
@pytest.mark.parametrize("enable_pdf_doctr", ['auto', 'on', 'off'])
@pytest.mark.parametrize("enable_pdf_ocr", ['auto', 'on', 'off'])
@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_image_pdf_add(db_type, enable_pdf_ocr, enable_pdf_doctr, use_pymupdf, use_unstructured_pdf, use_pypdf):
    """Index ``tests/CityofTshwaneWater.pdf`` under every parser/OCR option combination.

    Combinations with every extractor off, or whose option string matches an entry
    in ``no_docs``, tolerate a "no valid text" ingestion error and end early. In the
    default (pymupdf) configuration the second hit must contain known text; in other
    configurations the first two hits only need to be non-empty. ``kill_weaviate``
    runs on every exit path.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    try:
        with tempfile.TemporaryDirectory() as tmp_persist_directory:
            with tempfile.TemporaryDirectory() as tmp_user_path:
                name = 'CityofTshwaneWater.pdf'
                shutil.copy(os.path.join("tests", name), tmp_user_path)
                test_file1 = os.path.join(tmp_user_path, name)

                # e.g. 'chroma-off-off-auto-off-auto'; substring-matched against no_docs.
                str_test = '-'.join(str(x) for x in [db_type, enable_pdf_ocr, enable_pdf_doctr,
                                                     use_pymupdf, use_unstructured_pdf, use_pypdf])

                default_mode = use_pymupdf in ['auto', 'on'] and \
                               use_pypdf in ['off', 'auto'] and \
                               use_unstructured_pdf in ['auto'] and \
                               enable_pdf_doctr in ['off', 'auto'] and \
                               enable_pdf_ocr in ['off', 'auto']
                no_doc_mode = use_pymupdf == 'off' and \
                              use_pypdf == 'off' and \
                              use_unstructured_pdf == 'off' and \
                              enable_pdf_doctr == 'off' and \
                              enable_pdf_ocr in ['off', 'auto']
                # ocr-doctr-pymupdf-unstructured-pypdf settings (OCR and doctr both off)
                # that are allowed to extract no text from this PDF.
                no_docs = ['off-off-auto-off-auto', 'off-off-on-off-on', 'off-off-auto-off-off',
                           'off-off-off-off-auto', 'off-off-on-off-off', 'off-off-on-off-auto',
                           'off-off-auto-off-on', 'off-off-off-off-on',
                           ]
                no_doc_mode |= any(x in str_test for x in no_docs)

                try:
                    db, collection_name = make_db_main(persist_directory=tmp_persist_directory,
                                                       user_path=tmp_user_path,
                                                       fail_any_exception=True, db_type=db_type,
                                                       use_pymupdf=use_pymupdf,
                                                       enable_pdf_ocr=enable_pdf_ocr,
                                                       enable_pdf_doctr=enable_pdf_doctr,
                                                       use_unstructured_pdf=use_unstructured_pdf,
                                                       use_pypdf=use_pypdf,
                                                       add_if_exists=False)
                except Exception as e:
                    no_text = 'had no valid text and no meta data was parsed' in str(e) or \
                              'had no valid text, but meta data was parsed' in str(e)
                    if no_text and no_doc_mode:
                        return
                    raise

                assert db is not None
                docs = db.similarity_search("List Tshwane's concerns about water.")
                # Both branches below inspect docs[1]; fail clearly instead of IndexError.
                assert len(docs) >= 2
                if default_mode:
                    second = docs[1].page_content
                    assert 'we appeal to residents that do have water to please use it sparingly.' in second or \
                           'OFFICE OF THE MMC FOR UTILITIES AND REGIONAL' in second
                else:
                    assert docs[0].page_content
                    assert docs[1].page_content
                top_source = docs[0].metadata['source']
                if db_type == 'weaviate':
                    assert os.path.normpath(top_source) == os.path.normpath(test_file1) or \
                           os.path.basename(top_source) == os.path.basename(test_file1)
                else:
                    assert os.path.normpath(top_source) == os.path.normpath(test_file1)
    finally:
        kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_simple_pptx_add(db_type):
    """Download an example .pptx deck, index it, and query it back."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        deck_path = os.path.join(user_dir, 'sample.pptx')
        download_simple('https://www.suu.edu/webservices/styleguide/example-files/example.pptx',
                        dest=deck_path)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             add_if_exists=False,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("Example")
        assert len(hits) >= 1
        top = hits[0]
        assert 'Powerpoint' in top.page_content
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(deck_path)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_epub_add(db_type):
    """Download a one-chapter EPUB sample, index it, and query it back."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        epub_path = os.path.join(user_dir, 'sample.epub')
        download_simple('https://contentserver.adobe.com/store/books/GeographyofBliss_oneChapter.epub',
                        dest=epub_path)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             add_if_exists=False,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("Grump")
        assert len(hits) >= 1
        top = hits[0]
        assert any(word in top.page_content for word in ('happy', 'happiness'))
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(epub_path)
    kill_weaviate(db_type)


@pytest.mark.skip(reason="Not supported, GPL3, and msg-extractor code fails too often")
@pytest.mark.xfail(strict=False,
                   reason="fails with AttributeError: 'Message' object has no attribute '_MSGFile__stringEncoding'. Did you mean: '_MSGFile__overrideEncoding'? even though can use online converter to .eml fine.")
@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_msg_add(db_type):
    """Download a sample Outlook .msg file, index it, and query it back (currently skipped)."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        msg_path = os.path.join(user_dir, 'sample.msg')
        download_simple('http://file.fyicenter.com/b/sample.msg', dest=msg_path)
        db, _ = make_db_main(persist_directory=persist_dir,
                             user_path=user_dir,
                             db_type=db_type,
                             fail_any_exception=True)
        assert db is not None
        hits = db.similarity_search("Grump")
        assert len(hits) >= 1
        top = hits[0]
        assert 'Happy' in top.page_content
        assert os.path.normpath(top.metadata['source']) == os.path.normpath(msg_path)
    kill_weaviate(db_type)


# Import-time side effect: extract tests/driverslicense.jpeg from its zip so the
# 'tests/driverslicense.jpeg' entry in test_png_add's parametrization exists.
# The shell `cd tests` is relative to the current working directory, so this
# presumably assumes pytest is launched from the repository root; the exit status
# is not checked, so a failed extraction is silently ignored here.
os.system('cd tests ; unzip -o driverslicense.jpeg.zip')


@pytest.mark.parametrize("file", ['data/pexels-evg-kowalievska-1170986_small.jpg',
                                  'data/Sample-Invoice-printable.png',
                                  'tests/driverslicense.jpeg.zip',
                                  'tests/driverslicense.jpeg'])
@pytest.mark.parametrize("db_type", db_types)
@pytest.mark.parametrize("enable_pix2struct", [False, True])
@pytest.mark.parametrize("enable_doctr", [False, True])
@pytest.mark.parametrize("enable_ocr", [False, True])
@pytest.mark.parametrize("enable_captions", [False, True])
@pytest.mark.parametrize("pre_load_image_audio_models", [False, True])
@pytest.mark.parametrize("caption_gpu", [False, True])
@pytest.mark.parametrize("captions_model", [None, 'microsoft/Florence-2-large'])
@wrap_test_forked
@pytest.mark.parallel10
def test_png_add(captions_model, caption_gpu, pre_load_image_audio_models, enable_captions,
                 enable_doctr, enable_pix2struct, enable_ocr, db_type, file):
    """Index one image file under many captioning/OCR/doctr/pix2struct settings.

    Unsupported or redundant combinations return early (see the comments on each
    guard). Ingestion and checks are delegated to ``run_png_add``. Without captions,
    a "no valid text" failure is tolerated for the pexels photo only. ``kill_weaviate``
    runs after the call whether it succeeds, is tolerated, or re-raises.
    """
    if not have_gpus and caption_gpu:
        # if have no GPUs, don't enable caption on GPU
        return
    if not caption_gpu and captions_model == 'microsoft/Florence-2-large':
        # RuntimeError: "slow_conv2d_cpu" not implemented for 'Half'
        return
    if not enable_captions and pre_load_image_audio_models:
        # nothing to preload if not enabling captions
        return
    if captions_model == 'microsoft/Florence-2-large' and not (have_gpus and mem_gpus[0] > 20 * 1024 ** 3):
        # requires GPUs and enough memory to run
        return
    if not (enable_ocr or enable_doctr or enable_pix2struct or enable_captions):
        # nothing enabled for images
        return
    # FIXME (too many permutations):
    if enable_pix2struct and (
            pre_load_image_audio_models or enable_captions or enable_ocr or enable_doctr or captions_model or caption_gpu):
        return
    if enable_pix2struct and 'kowalievska' in file:
        # FIXME: Not good for this
        return
    kill_weaviate(db_type)
    try:
        return run_png_add(captions_model=captions_model, caption_gpu=caption_gpu,
                           pre_load_image_audio_models=pre_load_image_audio_models,
                           enable_captions=enable_captions,
                           enable_ocr=enable_ocr,
                           enable_doctr=enable_doctr,
                           enable_pix2struct=enable_pix2struct,
                           db_type=db_type,
                           file=file)
    except Exception as e:
        tolerated = not enable_captions and \
                    'data/pexels-evg-kowalievska-1170986_small.jpg' in file and \
                    'had no valid text and no meta data was parsed' in str(e)
        if not tolerated:
            raise
    finally:
        # Runs on the successful return path too, not only after a tolerated failure.
        kill_weaviate(db_type)


def run_png_add(captions_model=None, caption_gpu=False,
                pre_load_image_audio_models=False,
                enable_captions=True,
                enable_ocr=False,
                enable_doctr=False,
                enable_pix2struct=False,
                db_type='chroma',
                file='data/pexels-evg-kowalievska-1170986_small.jpg'):
    """Ingest a single image-like file into a fresh db and check the image handlers' output.

    The file is copied into a temporary user path. It is then ingested with
    ``make_db_main`` into a temporary persist directory, with the image-handler options
    passed through (captions, pix2struct, DocTR, OCR).

    The assertions depend on which handler groups are enabled ("captions or pix2struct",
    DocTR, OCR) and on which file was ingested:

    - ``kowalievska`` (the cat photo): with captions/pix2struct, a "cat" search must
      return 'cat sitting' in the top doc. Without them, no db is expected
      (``db is None``).
    - ``Sample-Invoice-printable``: weak checks only. When only captions/pix2struct are
      on, the top "invoice" hit must mention plumbing or invoice. Otherwise the only
      check is ``db is not None``.
    - any other file: a "license" search must pass the content check of every enabled
      handler. Given those checks, these are presumably the driver-license sample files.

    :param file: path relative to the current directory; if it does not exist,
                 ``../<file>`` is tried (for runs started from the ``tests`` directory).
    :raises AssertionError: if the ingestion output does not match expectations.
    :raises NotImplementedError: if none of captions, pix2struct, DocTR or OCR is enabled.
    """
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as tmp_persist_directory:
        with tempfile.TemporaryDirectory() as tmp_user_path:
            test_file1 = file
            if not os.path.isfile(test_file1):
                # see if ran from tests directory
                test_file1 = os.path.join('../', file)
                assert os.path.isfile(test_file1)
            test_file1 = os.path.abspath(test_file1)
            # user_path is a fresh directory holding only this copy of the file
            shutil.copy(test_file1, tmp_user_path)
            # from here on, sources are checked against the copy, not the original path
            test_file1 = os.path.join(tmp_user_path, os.path.basename(test_file1))
            # fail_if_no_sources=False presumably lets make_db_main return db=None when nothing
            # is extracted; the kowalievska OCR/DocTR-only branches below assert exactly that.
            db, collection_name = make_db_main(persist_directory=tmp_persist_directory, user_path=tmp_user_path,
                                               fail_any_exception=True,
                                               enable_ocr=enable_ocr,
                                               enable_pdf_ocr='auto',
                                               enable_pdf_doctr=False,
                                               caption_gpu=caption_gpu,
                                               pre_load_image_audio_models=pre_load_image_audio_models,
                                               captions_model=captions_model,
                                               enable_captions=enable_captions,
                                               enable_doctr=enable_doctr,
                                               enable_pix2struct=enable_pix2struct,
                                               db_type=db_type,
                                               add_if_exists=False,
                                               fail_if_no_sources=False)
            # captions/pix2struct only
            if (enable_captions or enable_pix2struct) and not enable_doctr and not enable_ocr:
                if 'kowalievska' in file:
                    docs = db.similarity_search("cat", k=10)
                    assert len(docs) >= 1
                    assert 'cat sitting' in docs[0].page_content
                    check_source(docs, test_file1)
                elif 'Sample-Invoice-printable' in file:
                    docs = db.similarity_search("invoice", k=10)
                    assert len(docs) >= 1
                    # weak test
                    assert 'plumbing' in docs[0].page_content.lower() or 'invoice' in docs[0].page_content.lower()
                    check_source(docs, test_file1)
                else:
                    docs = db.similarity_search("license", k=10)
                    assert len(docs) >= 1
                    check_content_captions(docs, captions_model, enable_pix2struct)
                    check_source(docs, test_file1)
            # OCR only
            elif not (enable_captions or enable_pix2struct) and not enable_doctr and enable_ocr:
                if 'kowalievska' in file:
                    # no text found in the cat photo without a captioner, so no db
                    assert db is None
                elif 'Sample-Invoice-printable' in file:
                    # weak test
                    assert db is not None
                else:
                    docs = db.similarity_search("license", k=10)
                    assert len(docs) >= 1
                    check_content_ocr(docs)
                    check_source(docs, test_file1)
            # DocTR only
            elif not (enable_captions or enable_pix2struct) and enable_doctr and not enable_ocr:
                if 'kowalievska' in file:
                    assert db is None
                elif 'Sample-Invoice-printable' in file:
                    # weak test
                    assert db is not None
                else:
                    docs = db.similarity_search("license", k=10)
                    assert len(docs) >= 1
                    check_content_doctr(docs)
                    check_source(docs, test_file1)
            # DocTR + OCR
            elif not (enable_captions or enable_pix2struct) and enable_doctr and enable_ocr:
                if 'kowalievska' in file:
                    assert db is None
                elif 'Sample-Invoice-printable' in file:
                    # weak test
                    assert db is not None
                else:
                    docs = db.similarity_search("license", k=10)
                    assert len(docs) >= 1
                    check_content_doctr(docs)
                    check_content_ocr(docs)
                    check_source(docs, test_file1)
            # captions/pix2struct + OCR
            elif (enable_captions or enable_pix2struct) and not enable_doctr and enable_ocr:
                if 'kowalievska' in file:
                    docs = db.similarity_search("cat", k=10)
                    assert len(docs) >= 1
                    assert 'cat sitting' in docs[0].page_content
                    check_source(docs, test_file1)
                elif 'Sample-Invoice-printable' in file:
                    # weak test
                    assert db is not None
                else:
                    docs = db.similarity_search("license", k=10)
                    assert len(docs) >= 1
                    check_content_ocr(docs)
                    check_content_captions(docs, captions_model, enable_pix2struct)
                    check_source(docs, test_file1)
            # captions/pix2struct + DocTR
            elif (enable_captions or enable_pix2struct) and enable_doctr and not enable_ocr:
                if 'kowalievska' in file:
                    docs = db.similarity_search("cat", k=10)
                    assert len(docs) >= 1
                    assert 'cat sitting' in docs[0].page_content
                    check_source(docs, test_file1)
                elif 'Sample-Invoice-printable' in file:
                    # weak test
                    assert db is not None
                else:
                    docs = db.similarity_search("license", k=10)
                    assert len(docs) >= 1
                    check_content_doctr(docs)
                    check_content_captions(docs, captions_model, enable_pix2struct)
                    check_source(docs, test_file1)
            # captions/pix2struct + DocTR + OCR
            elif (enable_captions or enable_pix2struct) and enable_doctr and enable_ocr:
                if 'kowalievska' in file:
                    docs = db.similarity_search("cat", k=10)
                    assert len(docs) >= 1
                    assert 'cat sitting' in docs[0].page_content
                    check_source(docs, test_file1)
                elif 'Sample-Invoice-printable' in file:
                    # weak test
                    assert db is not None
                else:
                    # chunk count is only checked for chroma, via its db.get()
                    if db_type == 'chroma':
                        assert len(db.get()['documents']) >= 4
                    docs = db.similarity_search("license", k=10)
                    # because search can't find DRIVERLICENSE from DocTR one
                    # (hence check_content_doctr is disabled below for this combination)
                    assert len(docs) >= 1
                    check_content_ocr(docs)
                    # check_content_doctr(docs)
                    check_content_captions(docs, captions_model, enable_pix2struct)
                    check_source(docs, test_file1)
            else:
                # no image handler enabled: callers are expected to skip this permutation
                raise NotImplementedError()


def check_content_captions(docs, captions_model, enable_pix2struct):
    """Assert that a caption (or pix2struct output) for a driver-license image is among ``docs``.

    At least one doc must mention 'license'. At least one doc must also contain one of the
    expected caption strings. Both checks ignore case.

    pix2struct is expected to yield "california license". If a Florence captions model was
    chosen, the full-sentence captions are expected instead, as for the default captioner.
    The model-name check ignores case, because the real name is
    'microsoft/Florence-2-large'.

    :param docs: documents exposing a ``page_content`` string.
    :param captions_model: captions model name, or None for the default captioner.
    :param enable_pix2struct: whether pix2struct was enabled during ingestion.
    :raises AssertionError: if either check fails.
    """
    contents = [doc.page_content.lower() for doc in docs]
    assert any('license' in content for content in contents), contents

    is_florence = captions_model is not None and 'florence' in captions_model.lower()
    if enable_pix2struct and not is_florence:
        expected = ["california license"]
    else:
        # Florence and the default captioner share the same expected captions.
        expected = ["The image shows a California driver's license with a picture of a woman's face on it.",
                    "The image is a California driver's license."]
    assert any(exp.lower() in content for exp in expected for content in contents), (expected, contents)


def check_content_doctr(docs):
    """Assert that each expected DocTR field appears (case-sensitively) in at least one doc."""
    page_texts = [doc.page_content for doc in docs]
    for needle in ('DRIVER LICENSE', 'California', 'ExP08/31/2014', 'VETERAN'):
        assert any(needle in text for text in page_texts)


def check_content_ocr(docs):
    """Assert that OCR output contains 'DRIVER LICENSE' (case-sensitive) in at least one doc.

    Only the ocr_only expectation is checked; with hi_res, 'Californias' was also seen.
    """
    matching = [doc for doc in docs if 'DRIVER LICENSE' in doc.page_content]
    assert matching


def check_source(docs, test_file1):
    """Assert that the top doc's ``source`` metadata refers to ``test_file1``.

    Non-zip inputs must match exactly after path normalization. For ``.zip`` inputs the
    loader extracts into a sub-directory, e.g.
    '/tmp/tmp63h5dxxv/driverslicense.jpeg.zip_d7d5f561-6/driverslicense.jpeg'. So only
    the archive's basename has to appear in the source path.
    """
    top_source = os.path.normpath(docs[0].metadata['source'])
    expected_path = os.path.normpath(test_file1)
    if not test_file1.endswith('.zip'):
        assert top_source == expected_path
        return
    assert os.path.basename(expected_path) in top_source


@pytest.mark.parametrize("image_file", ['./models/anthropic.png', 'data/pexels-evg-kowalievska-1170986_small.jpg'])
@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_caption_add(image_file, db_type):
    """Ingest an image with only LLaVA captioning enabled and check that the caption is searchable.

    The LLaVA model is taken from the ``H2OGPT_LLAVA_MODEL`` environment variable. All other
    image/PDF/audio handlers are turned off.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as tmp_persist_directory, \
            tempfile.TemporaryDirectory() as tmp_user_path:
        copied_image = os.path.join(tmp_user_path, os.path.basename(image_file))
        shutil.copy(image_file, copied_image)

        db, _collection_name = make_db_main(persist_directory=tmp_persist_directory, user_path=tmp_user_path,
                                            fail_any_exception=True, db_type=db_type,
                                            add_if_exists=False,
                                            enable_llava=True,
                                            llava_model=os.getenv('H2OGPT_LLAVA_MODEL'),
                                            llava_prompt=None,
                                            enable_doctr=False,
                                            enable_captions=False,
                                            enable_ocr=False,
                                            enable_transcriptions=False,
                                            enable_pdf_ocr=False,
                                            enable_pdf_doctr=False,
                                            enable_pix2struct=False,
                                            )
        assert db is not None

        is_anthropic_logo = 'anthropic' in image_file
        hits = db.similarity_search("circle" if is_anthropic_logo else "cat")
        assert len(hits) >= 1
        top_text = hits[0].page_content
        if is_anthropic_logo:
            assert 'AI' in top_text
        else:
            assert 'cat' in top_text
            assert any(word in top_text for word in ('window', 'outdoors', 'outside'))
        assert os.path.normpath(hits[0].metadata['source']) == os.path.normpath(copied_image)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_simple_rtf_add(db_type):
    """Ingest a small RTF document and check that its text is searchable.

    The RTF source is a raw string. Otherwise Python would turn control words such as
    \\rtf1, \\red0, \\fonttbl, \\blue0, \\title and \\author into carriage returns, form
    feeds, backspaces, tabs and bells, producing invalid RTF. Sequences like \\c, \\s and
    \\p would also raise invalid-escape warnings. The leading newline is stripped so the
    file starts with the required '{\\rtf1' header.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as tmp_persist_directory:
        with tempfile.TemporaryDirectory() as tmp_user_path:
            rtf_content = r"""
{\rtf1\mac\deff2 {\fonttbl{\f0\fswiss Chicago;}{\f2\froman New York;}{\f3\fswiss Geneva;}{\f4\fmodern Monaco;}{\f11\fnil Cairo;}{\f13\fnil Zapf Dingbats;}{\f16\fnil Palatino;}{\f18\fnil Zapf Chancery;}{\f20\froman Times;}{\f21\fswiss Helvetica;}
{\f22\fmodern Courier;}{\f23\ftech Symbol;}{\f24\fnil Mobile;}{\f100\fnil FoxFont;}{\f107\fnil MathMeteor;}{\f164\fnil Futura;}{\f1024\fnil American Heritage;}{\f2001\fnil Arial;}{\f2005\fnil Courier New;}{\f2010\fnil Times New Roman;}
{\f2011\fnil Wingdings;}{\f2515\fnil MT Extra;}{\f3409\fnil FoxPrint;}{\f11132\fnil InsigniaLQmono;}{\f11133\fnil InsigniaLQprop;}{\f14974\fnil LB Helvetica Black;}{\f14976\fnil L Helvetica Light;}}{\colortbl\red0\green0\blue0;\red0\green0\blue255;
\red0\green255\blue255;\red0\green255\blue0;\red255\green0\blue255;\red255\green0\blue0;\red255\green255\blue0;\red255\green255\blue255;}{\stylesheet{\f4\fs18 \sbasedon222\snext0 Normal;}}{\info{\title samplepostscript.msw}{\author 
Computer Science Department}}\widowctrl\ftnbj \sectd \sbknone\linemod0\linex0\cols1\endnhere \pard\plain \qc \f4\fs18 {\plain \b\f21 Sample Rich Text Format Document\par 
}\pard {\plain \f20 \par 
}\pard \ri-80\sl-720\keep\keepn\absw570 {\caps\f20\fs92\dn6 T}{\plain \f20 \par 
}\pard \qj {\plain \f20 his is a sample rich text format (RTF), document. This document was created using Microsoft Word and then printing the document to a RTF file. It illustrates the very basic text formatting effects that can be achieved using RTF. 
\par 
\par 
}\pard \qj\li1440\ri1440\box\brdrs \shading1000 {\plain \f20 RTF }{\plain \b\f20 contains codes for producing advanced editing effects. Such as this indented, boxed, grayed background, entirely boldfaced paragraph.\par 
}\pard \qj {\plain \f20 \par 
Microsoft  Word developed RTF for document transportability and gives a user access to the complete set of the effects that can be achieved using RTF. \par 
}}
""".lstrip()
            test_file1 = os.path.join(tmp_user_path, 'test.rtf')
            with open(test_file1, "wt") as f:
                f.write(rtf_content)
            db, collection_name = make_db_main(persist_directory=tmp_persist_directory, user_path=tmp_user_path,
                                               fail_any_exception=True, db_type=db_type,
                                               add_if_exists=False)
            assert db is not None
            docs = db.similarity_search("How was this document created?")
            assert len(docs) >= 1
            assert 'Microsoft' in docs[0].page_content
            assert os.path.normpath(docs[0].metadata['source']) == os.path.normpath(test_file1)
    kill_weaviate(db_type)


# Weaviate's EmbeddedDB does not support Windows (feature request:
# https://github.com/weaviate/weaviate-python-client/issues/239), so this test only runs against chroma.
@pytest.mark.parametrize("db_type", ['chroma'])
@wrap_test_forked
def test_url_more_add(db_type):
    """Ingest a CNN article by URL and check that 'Ukraine' is found in the top search hit."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    article_url = 'https://edition.cnn.com/2023/08/19/europe/ukraine-f-16s-counteroffensive-intl/index.html'
    with tempfile.TemporaryDirectory() as persist_dir:
        db, _collection_name = make_db_main(persist_directory=persist_dir, url=article_url,
                                            fail_any_exception=True, db_type=db_type)
        assert db is not None
        hits = db.similarity_search("Ukraine")
        assert len(hits) >= 1
        assert 'Ukraine' in hits[0].page_content
    kill_weaviate(db_type)


# Small nested JSON quiz used as an inline fixture by the JSON / JSONL.gz ingestion tests.
# Those tests search for "NBA" and expect 'Bulls' in the top hit. The quiz text is only
# indexed, never checked for accuracy, so its spelling ("Warriros", "Huston Rocket") is
# left as is.
json_data = {
    "quiz": {
        "sport": {
            "q1": {
                "question": "Which one is correct team name in NBA?",
                "options": [
                    "New York Bulls",
                    "Los Angeles Kings",
                    "Golden State Warriros",
                    "Huston Rocket"
                ],
                "answer": "Huston Rocket"
            }
        },
        "maths": {
            "q1": {
                "question": "5 + 7 = ?",
                "options": [
                    "10",
                    "11",
                    "12",
                    "13"
                ],
                "answer": "12"
            },
            "q2": {
                "question": "12 - 8 = ?",
                "options": [
                    "1",
                    "2",
                    "3",
                    "4"
                ],
                "answer": "4"
            }
        }
    }
}


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_json_add(db_type):
    """Ingest a small JSON file built from ``json_data`` and check that the NBA question is found."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        # An inline fixture replaces the ShareGPT download, which is too slow:
        # eval_filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
        # url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s" % eval_filename
        json_path = os.path.join(user_dir, 'sample.json')
        with open(json_path, 'wt') as out:
            json.dump(json_data, out)

        db, _collection_name = make_db_main(persist_directory=persist_dir, user_path=user_dir,
                                            fail_any_exception=True, db_type=db_type,
                                            add_if_exists=False)
        assert db is not None
        hits = db.similarity_search("NBA")
        assert len(hits) >= 1
        assert 'Bulls' in hits[0].page_content
        assert os.path.normpath(hits[0].metadata['source']) == os.path.normpath(json_path)
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_jsonl_gz_add(db_type):
    """Ingest a gzip-compressed JSON file and check that the NBA question is found.

    The recorded source is expected to be the decompressed name (the path without '.gz').
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        # A real-world alternative (not used, to keep the test offline and fast):
        # url = "https://huggingface.co/datasets/OpenAssistant/oasst1/resolve/main/2023-04-12_oasst_spam.messages.jsonl.gz"
        gz_path = os.path.join(user_dir, 'sample.jsonl.gz')
        payload = json.dumps(json_data).encode()
        with gzip.open(gz_path, 'wb') as out:
            out.write(payload)

        db, _collection_name = make_db_main(persist_directory=persist_dir, user_path=user_dir,
                                            fail_any_exception=True, db_type=db_type,
                                            add_if_exists=False)
        assert db is not None
        hits = db.similarity_search("NBA")
        assert len(hits) >= 1
        assert 'Bulls' in hits[0].page_content
        expected_source = os.path.normpath(gz_path).replace('.gz', '')
        assert os.path.normpath(hits[0].metadata['source']) == expected_source
    kill_weaviate(db_type)


@wrap_test_forked
def test_url_more_subunit():
    """Check that the Unstructured, Playwright and Selenium URL loaders each return non-empty docs."""

    def non_empty(loaded_docs):
        # keep only documents that actually carry text
        return [doc for doc in loaded_docs if doc.page_content]

    cnn_url = 'https://edition.cnn.com/2023/08/19/europe/ukraine-f-16s-counteroffensive-intl/index.html'
    from langchain.document_loaders import UnstructuredURLLoader
    assert len(non_empty(UnstructuredURLLoader(urls=[cnn_url]).load())) > 0

    # Playwright and Selenium fail on the CNN URL, so they get a simpler page.
    simple_url = 'https://github.com/h2oai/h2ogpt'

    from langchain.document_loaders import PlaywrightURLLoader
    assert len(non_empty(PlaywrightURLLoader(urls=[simple_url]).load())) > 0

    from langchain.document_loaders import SeleniumURLLoader
    assert len(non_empty(SeleniumURLLoader(urls=[simple_url]).load())) > 0


@wrap_test_forked
@pytest.mark.parametrize("db_type", db_types_full)
@pytest.mark.parametrize("num", [1000, 100000])
def test_many_text(db_type, num):
    """Store ``num`` one-line documents in a db and check that all of them come back."""
    from langchain.docstore.document import Document

    tiny_docs = [Document(page_content=str(idx)) for idx in range(num)]
    # "fake" is presumably a stub embedding that keeps this fast; real alternatives are
    # "sentence-transformers/all-MiniLM-L6-v2" or 'BAAI/bge-large-en-v1.5'.
    embedding_model = "fake"
    db = get_db(tiny_docs, db_type=db_type, langchain_mode='ManyTextData', hf_embedding_model=embedding_model)
    stored = get_documents(db)['documents']
    assert len(stored) == num


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_youtube_audio_add(db_type):
    """Ingest a YouTube video by URL with ``extract_frames=0`` and check the transcript text.

    The source is a URL, so only a temporary persist directory is needed. No user_path
    directory is created.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as tmp_persist_directory:
        url = 'https://www.youtube.com/watch?v=cwjs1WAG9CM'
        db, collection_name = make_db_main(persist_directory=tmp_persist_directory, url=url,
                                           fail_any_exception=True, db_type=db_type,
                                           add_if_exists=False,
                                           extract_frames=0)
        assert db is not None
        docs = db.similarity_search("Example")
        assert len(docs) >= 1
        assert 'Contrasting this' in docs[0].page_content
        assert url in docs[0].metadata['source']
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_youtube_full_add(db_type):
    """Ingest a YouTube short by URL with default settings and check the extracted content.

    The top "cat" hits must mention a couch. A wider search (100 results) must also mention
    an egg. The source URL may be recorded under 'source' or 'original_source'. The input
    is a URL, so no temporary user_path directory is created.
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as tmp_persist_directory:
        url = 'https://www.youtube.com/shorts/JjdqlglRxrU'
        db, collection_name = make_db_main(persist_directory=tmp_persist_directory, url=url,
                                           fail_any_exception=True, db_type=db_type,
                                           add_if_exists=False)
        assert db is not None
        docs = db.similarity_search("cat")
        assert len(docs) >= 1
        assert 'couch' in str([x.page_content for x in docs])
        assert url in docs[0].metadata['source'] or url in docs[0].metadata['original_source']
        docs = db.similarity_search("cat", 100)
        assert 'egg' in str([x.page_content for x in docs])
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_mp3_add(db_type):
    """Ingest a zipped MP3 recording and check that its transcript mentions the Porsche Macan."""
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        archive_path = os.path.join(user_dir, 'sample.mp3.zip')
        # relative path: expects to be run from the repository root
        shutil.copy('tests/porsche.mp3.zip', archive_path)
        db, _collection_name = make_db_main(persist_directory=persist_dir, user_path=user_dir,
                                            fail_any_exception=True, db_type=db_type)
        assert db is not None
        hits = db.similarity_search("Porsche")
        assert len(hits) >= 1
        assert 'Porsche Macan' in hits[0].page_content
        # The archive is renamed on copy, so the source presumably names the member inside it.
        assert 'porsche.mp3' in os.path.normpath(hits[0].metadata['source'])
    kill_weaviate(db_type)


@pytest.mark.parametrize("db_type", db_types)
@wrap_test_forked
def test_mp4_add(db_type):
    """Ingest a downloaded MP4 with captions enabled and check the audio and frame-caption output.

    Documents must come from both the audio caption loader and the image caption loader,
    and some metadata must reference a '.jpg' (presumably an extracted frame).
    """
    kill_weaviate(db_type)
    from src.make_db import make_db_main
    with tempfile.TemporaryDirectory() as persist_dir, tempfile.TemporaryDirectory() as user_dir:
        video_url = 'https://h2o-release.s3.amazonaws.com/h2ogpt/iG_jeMeUPBnUO6sx.mp4'
        video_path = os.path.join(user_dir, 'demo.mp4')
        download_simple(video_url, dest=video_path)
        db, _collection_name = make_db_main(persist_directory=persist_dir, user_path=user_dir,
                                            fail_any_exception=True, db_type=db_type,
                                            enable_captions=True)
        assert db is not None

        hits = db.similarity_search("Gemini")
        assert len(hits) >= 1
        assert 'Gemini' in str([doc.page_content for doc in hits])
        assert 'demo.mp4' in os.path.normpath(hits[0].metadata['source'])

        wide_hits = db.similarity_search("AI", 100)
        all_text = str([doc.page_content for doc in wide_hits])
        all_meta = str([doc.metadata for doc in wide_hits])
        for phrase in ('fun birthday party', 'Gemini tries to design'):
            assert phrase in all_text
        for marker in ('H2OAudioCaptionLoader', 'H2OImageCaptionLoader', '.jpg'):
            assert marker in all_meta
    kill_weaviate(db_type)


@wrap_test_forked
def test_chroma_filtering():
    # get test model so don't have to reload it each time
    model, tokenizer, base_model, prompt_type = get_test_model()

    # generic settings true for all cases
    requests_state1 = {'username': 'foo'}
    verbose1 = True
    max_raw_chunks = None
    api = False
    n_jobs = -1
    db_type1 = 'chroma'
    load_db_if_exists1 = True
    use_openai_embedding1 = False
    migrate_embedding_model_or_db1 = False

    def get_userid_auth_fake(requests_state1, auth_filename=None, auth_access=None, guest_name=None, **kwargs):
        return str(uuid.uuid4())

    other_kwargs = dict(load_db_if_exists1=load_db_if_exists1,
                        db_type1=db_type1,
                        use_openai_embedding1=use_openai_embedding1,
                        migrate_embedding_model_or_db1=migrate_embedding_model_or_db1,
                        verbose1=verbose1,
                        get_userid_auth1=get_userid_auth_fake,
                        max_raw_chunks=max_raw_chunks,
                        api=api,
                        n_jobs=n_jobs,
                        enforce_h2ogpt_api_key=False,
                        enforce_h2ogpt_ui_key=False,
                        )
    mydata_mode1 = LangChainMode.MY_DATA.value
    from src.make_db import make_db_main

    for chroma_new in [True]:
        print("chroma_new: %s" % chroma_new, flush=True)
        if chroma_new:
            # fresh, so chroma >= 0.4
            user_path = make_user_path_test()
            from langchain_community.vectorstores import Chroma
            db, collection_name = make_db_main(user_path=user_path)
            assert isinstance(db, Chroma)

            hf_embedding_model = 'hkunlp/instructor-xl'
            langchain_mode1 = collection_name
            query = 'What is h2oGPT?'
        else:
            raise RuntimeError("Migration no longer supported")

        db1s = {langchain_mode1: [None] * length_db1(), mydata_mode1: [None] * length_db1()}

        dbs1 = {langchain_mode1: db}
        langchain_modes = [langchain_mode1]
        langchain_mode_paths = dict(langchain_mode1=None)
        langchain_mode_types = dict(langchain_modes='shared')
        selection_docs_state1 = dict(langchain_modes=langchain_modes,
                                     langchain_mode_paths=langchain_mode_paths,
                                     langchain_mode_types=langchain_mode_types)

        run_db_kwargs = dict(query=query,
                             db=db,
                             use_openai_model=False, use_openai_embedding=False, text_limit=None,
                             hf_embedding_model=hf_embedding_model,
                             db_type=db_type1,
                             langchain_mode_paths=langchain_mode_paths,
                             langchain_mode_types=langchain_mode_types,
                             langchain_mode=langchain_mode1,
                             langchain_agents=[],
                             llamacpp_dict={},

                             model=model,
                             tokenizer=tokenizer,
                             model_name=base_model,
                             prompt_type=prompt_type,

                             top_k_docs=10,  # 4 leaves out docs for test in some cases, so use 10
                             cut_distance=1.8,  # default leaves out some docs in some cases
                             )

        # GET_CHAIN etc.
        for answer_with_sources in [-1, True]:
            print("answer_with_sources: %s" % answer_with_sources, flush=True)
            # mimic nochat-API or chat-UI
            append_sources_to_answer = answer_with_sources != -1
            for doc_choice in ['All', 1, 2]:
                if doc_choice == 'All':
                    document_choice = [DocumentChoice.ALL.value]
                else:
                    docs = [x['source'] for x in db.get()['metadatas']]
                    if doc_choice == 1:
                        document_choice = docs[:doc_choice]
                    else:
                        # ensure don't get dup
                        docs = sorted(set(docs))
                        document_choice = docs[:doc_choice]
                print("doc_choice: %s" % doc_choice, flush=True)
                for langchain_action in [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value]:
                    print("langchain_action: %s" % langchain_action, flush=True)
                    for document_subset in [DocumentSubset.Relevant.name, DocumentSubset.TopKSources.name,
                                            DocumentSubset.RelSources.name]:
                        print("document_subset: %s" % document_subset, flush=True)

                        ret = _run_qa_db(**run_db_kwargs,
                                         langchain_action=langchain_action,
                                         document_subset=document_subset,
                                         document_choice=document_choice,
                                         answer_with_sources=answer_with_sources,
                                         append_sources_to_answer=append_sources_to_answer,
                                         )
                        rets = check_ret(ret)
                        rets1 = rets[0]
                        if chroma_new:
                            if answer_with_sources == -1:
                                assert len(rets1) >= 7 and (
                                        'h2oGPT' in rets1['response'] or 'H2O GPT' in rets1['response'] or 'H2O.ai' in
                                        rets1['response'])
                            else:
                                assert len(rets1) >= 7 and (
                                        'h2oGPT' in rets1['response'] or 'H2O GPT' in rets1['response'] or 'H2O.ai' in
                                        rets1['response'])
                                if document_subset == DocumentSubset.Relevant.name:
                                    assert 'h2oGPT' in str(rets1['sources'])
                        else:
                            if answer_with_sources == -1:
                                assert len(rets1) >= 7 and (
                                        'whisper' in rets1['response'].lower() or
                                        'phase' in rets1['response'].lower() or
                                        'generate' in rets1['response'].lower() or
                                        'statistic' in rets1['response'].lower() or
                                        'a chat bot that' in rets1['response'].lower() or
                                        'non-centrality parameter' in rets1['response'].lower() or
                                        '.pdf' in rets1['response'].lower() or
                                        'gravitational' in rets1['response'].lower() or
                                        'answer to the question' in rets1['response'].lower() or
                                        'not responsible' in rets1['response'].lower()
                                )
                            else:
                                assert len(rets1) >= 7 and (
                                        'whisper' in rets1['response'].lower() or
                                        'phase' in rets1['response'].lower() or
                                        'generate' in rets1['response'].lower() or
                                        'statistic' in rets1['response'].lower() or
                                        '.pdf' in rets1['response'].lower())
                                if document_subset == DocumentSubset.Relevant.name:
                                    assert 'whisper' in str(rets1['sources']) or \
                                           'unbiased' in str(rets1['sources']) or \
                                           'approximate' in str(rets1['sources'])
                        if answer_with_sources == -1:
                            if document_subset == DocumentSubset.Relevant.name:
                                assert 'score' in rets1['sources'][0] and 'content' in rets1['sources'][
                                    0] and 'source' in rets1['sources'][0]
                                if doc_choice in [1, 2]:
                                    if langchain_action == 'Summarize':
                                        assert len(set(flatten_list([x['source'].split(docs_joiner_default) for x in
                                                                     rets1['sources']]))) >= doc_choice
                                    else:
                                        assert len(set([x['source'] for x in rets1['sources']])) >= 1
                                else:
                                    assert len(set([x['source'] for x in rets1['sources']])) >= 1
                            elif document_subset == DocumentSubset.RelSources.name:
                                if doc_choice in [1, 2]:
                                    assert len(set([x['source'] for x in rets1['sources']])) <= doc_choice
                                else:
                                    if langchain_action == 'Summarize':
                                        assert len(set(flatten_list(
                                            [x['source'].split(docs_joiner_default) for x in rets1['sources']]))) >= 1
                                    else:
                                        assert len(set([x['source'] for x in rets1['sources']])) >= 1
                            else:
                                # TopK may just be 1 doc because of many chunks from that doc
                                # if top_k_docs=-1 might get more
                                assert len(set([x['source'] for x in rets1['sources']])) >= 1

        # SHOW DOC
        single_document_choice1 = [x['source'] for x in db.get()['metadatas']][0]
        text_context_list1 = []
        pdf_height = 800
        h2ogpt_key1 = ''
        for view_raw_text_checkbox1 in [True, False]:
            print("view_raw_text_checkbox1: %s" % view_raw_text_checkbox1, flush=True)
            from src.gradio_runner import show_doc
            show_ret = show_doc(db1s, selection_docs_state1, requests_state1,
                                langchain_mode1,
                                single_document_choice1,
                                view_raw_text_checkbox1,
                                text_context_list1,
                                pdf_height,
                                h2ogpt_key1,
                                dbs1=dbs1,
                                hf_embedding_model1=hf_embedding_model,
                                **other_kwargs
                                )
            assert len(show_ret) == 8
            if chroma_new:
                assert1 = show_ret[4]['value'] is not None and 'README.md' in show_ret[4]['value']
                assert2 = show_ret[3]['value'] is not None and 'h2oGPT' in show_ret[3]['value']
                assert assert1 or assert2
            else:
                assert1 = show_ret[4]['value'] is not None and single_document_choice1 in show_ret[4]['value']
                assert2 = show_ret[3]['value'] is not None and single_document_choice1 in show_ret[3]['value']
                assert assert1 or assert2


@pytest.mark.parametrize("max_input_tokens", [
    1024, None
])
@pytest.mark.parametrize("data_kind", [
    'simple',
    'helium1',
    'helium2',
    'helium3',
    'helium4',
    'helium5',
    'long',
    'very_long',
])
@wrap_test_forked
def test_merge_docs(data_kind, max_input_tokens):
    """Check that split_merge_docs packs each corpus into the expected number of chunks.

    For each ``data_kind`` corpus and token budget this verifies:
      * the number of merged documents returned by ``split_merge_docs``,
      * that every chunk (plus the joiner) fits within ``max_input_tokens``
        as counted by a super-fake tokenizer,
      * that splitting/merging finishes within a per-case time limit.

    ``max_input_tokens=None`` is replaced by ``model_max_length - 512``.
    """
    t0 = time.time()

    model_max_length = 4096
    if max_input_tokens is None:
        max_input_tokens = model_max_length - 512
    docs_joiner = docs_joiner_default
    docs_token_handling = docs_token_handling_default
    tokenizer = FakeTokenizer(model_max_length=model_max_length, is_super_fake=True)

    from langchain.docstore.document import Document
    # Selected lazily (not via a dict) so the large 'very_long' text is only
    # built when needed and does not eat into the 0.1s budgets of other cases.
    if data_kind == 'simple':
        texts = texts_simple
    elif data_kind == 'helium1':
        texts = texts_helium1
    elif data_kind == 'helium2':
        texts = texts_helium2
    elif data_kind == 'helium3':
        texts = texts_helium3
    elif data_kind == 'helium4':
        texts = texts_helium4
    elif data_kind == 'helium5':
        texts = texts_helium5
    elif data_kind == 'long':
        texts = texts_long
    elif data_kind == 'very_long':
        texts = ['\n'.join(texts_long * 100)]
    else:
        raise RuntimeError("BAD")

    docs_with_score = [(Document(page_content=page_content, metadata={"source": "%d" % pi}), 1.0) for pi, page_content
                       in enumerate(texts)]

    docs_with_score_new, max_docs_tokens = (
        split_merge_docs(docs_with_score, tokenizer=tokenizer, max_input_tokens=max_input_tokens,
                         docs_token_handling=docs_token_handling, joiner=docs_joiner, verbose=True))

    text_context_list = [x[0].page_content for x in docs_with_score_new]
    tokens = [get_token_count(x + docs_joiner, tokenizer) for x in text_context_list]
    print(tokens)

    small_budget = max_input_tokens == 1024
    # Expected merged-document counts per corpus:
    # (count with the 1024-token budget, count with the default larger budget).
    # NOTE: the original form `assert n == A if cond else B` asserted the truthy
    # constant B in the large-budget case, so those counts were never checked.
    expected_counts = {
        'simple': (1, 1),
        'helium1': (4, 2),
        'helium2': (7, 3),
        'helium3': (6, 2),
        'helium4': (6, 2),
        'helium5': (6, 1),
        'long': (47, 6),
        'very_long': (4601, 6),
    }
    expected = expected_counts[data_kind][0 if small_budget else 1]
    assert len(docs_with_score_new) == expected, (data_kind, max_input_tokens, len(docs_with_score_new), expected)
    assert all(x <= max_input_tokens for x in tokens)

    # Everything is fast except the very_long corpus, which gets a looser limit
    # (especially with the small budget, where it yields thousands of chunks).
    if data_kind == 'very_long':
        max_duration = 60 if small_budget else 10
    else:
        max_duration = 0.1
    assert time.time() - t0 < max_duration
    print("duration: %s" % (time.time() - t0), flush=True)


@wrap_test_forked
def test_split_and_merge():
    """Split/merge ``texts_long`` using the Llama-3-8B-Instruct tokenizer.

    Expects exactly 6 resulting documents, and the first one must start with
    'Y' (i.e. not with a stray '.' left over from sentence splitting).
    """
    from transformers import AutoTokenizer
    from langchain_core.documents import Document

    llama3_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nGive a summary that is well-structured yet concise.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"""\n\n"""\nWrite a summary for a physics Ph.D. and assistant professor in physics doing astrophysics, identifying key points of interest.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
    llama3_tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')

    scored_docs = [
        (Document(page_content=text, metadata={"source": "%d" % idx}), 1.0)
        for idx, text in enumerate(texts_long)
    ]

    merged_docs, _max_doc_tokens = split_merge_docs(
        scored_docs,
        llama3_tokenizer,
        max_input_tokens=7118,
        docs_token_handling='split_or_merge',
        joiner='\n\n',
        non_doc_prompt=llama3_prompt,
        verbose=False,
    )

    assert len(merged_docs) == 6
    # ensure the first document doesn't start with '.' from sentence splitting
    assert merged_docs[0][0].page_content.startswith('Y')


@wrap_test_forked
def test_crawl():
    """Crawl the h2oGPT GitHub page and expect the GPU README link to be discovered."""
    from src.gpt_langchain import Crawler
    final_urls = Crawler(urls=['https://github.com/h2oai/h2ogpt'], verbose=True).run()
    # Print before asserting: previously the print came after the assert and
    # so never ran on failure, which is exactly when the crawled URLs matter.
    print(final_urls)
    assert 'https://github.com/h2oai/h2ogpt/blob/main/docs/README_GPU.md' in final_urls


@wrap_test_forked
def test_hyde_acc():
    """Check get_hyde_acc's second return value for a string vs. a list answer.

    With both show-intermediate-in-accordion flags off, a plain string answer
    yields hyde == '' and a list answer yields hyde is None.
    """
    # (hyde_show_intermediate_in_accordion, map_reduce_show_intermediate_in_accordion)
    accordion_flags = (False, False)

    # A fresh llm_answers dict per call, matching the original test.
    _, hyde_text = get_hyde_acc('answer', dict(response_raw='raw'), *accordion_flags)
    assert hyde_text == ''

    _, hyde_text = get_hyde_acc(['answer'], dict(response_raw='raw'), *accordion_flags)
    assert hyde_text is None


if __name__ == '__main__':
    # Intentionally a no-op when executed directly.
    ...
